
# torch and torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

# ------------------------------------------------------------------------------
#   Define loss function
# --------------------------------------------------------------------
def define_loss_function(lossname):
    """Return the loss callable registered under ``lossname``.

    Args:
        lossname: name of the loss function. Only ``'cross-entropy'`` is
            currently supported.

    Returns:
        ``torch.nn.functional.cross_entropy`` when ``lossname`` is
        ``'cross-entropy'``.

    Raises:
        AssertionError: if ``lossname`` is not a supported name. This is
            raised explicitly rather than via an ``assert`` statement so the
            check still fires under ``python -O`` (which strips asserts and
            would otherwise make this function silently return ``None``).
    """
    # cross-entropy loss
    if lossname == 'cross-entropy':
        return F.cross_entropy

    # Undefined loss functions: keep AssertionError for existing callers.
    raise AssertionError(
        'Error: invalid loss function name [{}]'.format(lossname))


# ------------------------------------------------------------------------------
#   Define optimizer
# ------------------------------------------------------------------------------
def define_optimizer(net, opname, lrate, steps):
    """Build an optimizer, plus an optional LR scheduler, for ``net``.

    Args:
        net: module whose ``parameters()`` will be optimized.
        opname: optimizer name; one of ``'SGD'``, ``'Adam'``, ``'Adadelta'``
            or ``'Rmsprop'`` (exact, case-sensitive match).
        lrate: initial learning rate.
        steps: milestones passed to ``MultiStepLR``; the learning rate is
            multiplied by 0.5 at each milestone. Unused for ``'Adam'``.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is a ``MultiStepLR`` for
        every optimizer except ``'Adam'``, for which it is ``None``.

    Raises:
        AssertionError: if ``opname`` is not recognised. Raised explicitly
            (not with an ``assert`` statement) so the check survives
            ``python -O``; otherwise the function would fall through and
            fail with ``UnboundLocalError``.
    """
    # Stochastic Gradient Descent
    if opname == 'SGD':
        optimizer = optim.SGD(
            net.parameters(), lr=lrate, momentum=0.9, weight_decay=1e-04)

    # Adam: runs without a step schedule.
    elif opname == 'Adam':
        # NOTE(review): weight_decay=1e-1 is 1000x the 1e-4 used by the other
        # optimizers here; confirm this is intentional.
        optimizer = optim.Adam(
            net.parameters(), lr=lrate, betas=(0.5, 0.999),
            weight_decay=1e-1, amsgrad=True)
        return optimizer, None

    # Adadelta
    elif opname == 'Adadelta':
        optimizer = optim.Adadelta(
            net.parameters(), lr=lrate, weight_decay=1e-4)

    # RMSProp
    elif opname == 'Rmsprop':
        optimizer = optim.RMSprop(net.parameters(), lr=lrate)

    # undefined: keep AssertionError and the original message for callers.
    else:
        raise AssertionError(
            'Error: undefined optimizer [{}]'.format(opname))

    # Shared by all non-Adam optimizers: halve the LR at each milestone.
    scheduler = MultiStepLR(optimizer, milestones=steps, gamma=0.5)
    return optimizer, scheduler
